Mech Interp of Binarized Neural Networks¶
Question being explored: Recent papers have shown that binary and ternary transformer-based networks, with weights restricted to {-1, 1} or {-1, 0, 1}, can achieve results similar to full-precision networks. Are these networks simply simulating a full-precision network, or are they learning different and possibly more interpretable algorithms due to their discretized nature?
Setup: A 1-layer transformer with all weights binarized except for the embed and unembed. The specific implementation is based on the BitNet paper; the code is in the BitNet folder.
Findings: For modular addition, the binary transformer exhibits grokking in a very similar way to the full-precision model and seems to learn fundamentally the same algorithm; it is more or less just emulating a full-precision network. More analysis at the end.
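For reference, BitNet-style BitLinear layers ternarize their latent full-precision weights on the fly during the forward pass, while gradients flow to the float weights via the straight-through estimator. A minimal sketch of the absmean quantizer (illustrative, not the exact BitNet code; `absmean_ternarize` is a made-up name):

```python
def absmean_ternarize(w, eps=1e-5):
    """Map a full-precision weight matrix to {-1, 0, 1} by scaling with the
    mean absolute weight, then rounding and clamping (BitNet b1.58 style)."""
    flat = [abs(x) for row in w for x in row]
    gamma = sum(flat) / len(flat) + eps  # per-tensor scale
    return [[max(-1, min(1, round(x / gamma))) for x in row] for row in w]

# During training the ternarized weights are used in the forward pass, while
# the backward pass updates the underlying float weights as if quantization
# were the identity (the straight-through estimator).
print(absmean_ternarize([[0.5, -0.5], [1.2, -0.01]]))  # [[1, -1], [1, 0]]
```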
Setup¶
(No need to read)
TRAIN_MODEL = False
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer-lens
    %pip install circuitsvis
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
Running as a Jupyter notebook - intended for development only!
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")
Using renderer: notebook_connected
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
from typing import List, Union, Optional
from functools import partial
import copy
import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
HookedRootModule,
HookPoint,
) # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
Plotting helper functions:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x": xaxis, "y": yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs).show(renderer)
# Define the location to save the model, using a relative path
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"
# Create the directory if it does not exist
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)
Model Training¶
Config¶
p = 113
frac_train = 0.3
# Optimizer config
lr = 1e-3
wd = 1.
betas = (0.9, 0.98)
num_epochs = 25000
checkpoint_every = 100
DATA_SEED = 598
Define Task¶
- Define modular addition
- Define the dataset & labels
Input format: |a|b|=|
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(p), " -> (i j)", i=p, j=p)  # the '=' token is index p
dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()
print(dataset[:5])
print(dataset.shape)
tensor([[ 0, 0, 113],
[ 0, 1, 113],
[ 0, 2, 113],
[ 0, 3, 113],
[ 0, 4, 113]], device='cuda:0')
torch.Size([12769, 3])
labels = (dataset[:, 0] + dataset[:, 1]) % p
print(labels.shape)
print(labels[:5])
torch.Size([12769])
tensor([0, 1, 2, 3, 4], device='cuda:0')
Convert this to a train + test set - 30% in the training set
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p*p)
cutoff = int(p*p*frac_train)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]
train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)
tensor([[ 21, 31, 113],
[ 30, 98, 113],
[ 47, 10, 113],
[ 86, 21, 113],
[ 99, 83, 113]], device='cuda:0')
tensor([ 52, 15, 57, 107, 69], device='cuda:0')
torch.Size([3830, 3])
tensor([[ 43, 40, 113],
[ 31, 42, 113],
[ 39, 63, 113],
[ 35, 61, 113],
[112, 102, 113]], device='cuda:0')
tensor([ 83, 73, 102, 96, 101], device='cuda:0')
torch.Size([8939, 3])
Define Model¶
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True,
    device="cuda",
    seed=999,
)
model = HookedTransformer(cfg)
from bitnet import BitNetTransformer
model2 = BitNetTransformer(dim=128, depth=1, heads=4, in_features=p+1, out_features=p, random_seed=999).to("cuda")
tensor([[ 0., 0., 0., 0., -0., -1., 0., -1., 1., 1., -1., 0., -1., 0.,
-1., 1., -0., 1., 0., -1.],
[ 1., -0., 1., -0., -1., 1., -1., -1., 1., -1., -1., -1., -1., 1.,
-1., 0., 1., 1., 1., -1.],
[-1., -0., -1., -1., 0., 0., -1., 1., -1., -1., 0., -1., -0., -0.,
1., -0., -0., -1., -1., 1.],
[-1., 1., -1., -1., 1., -1., 1., 1., -0., -1., -1., -1., -1., 1.,
0., -1., 1., 0., -1., 1.],
[ 1., 1., -1., 1., -1., 1., -1., -0., 0., -1., 1., -0., -0., 1.,
-0., -1., 1., -1., 1., -0.],
[-1., -1., 1., -1., -0., -1., 1., 1., -1., -1., -0., -1., -1., 0.,
1., -1., -1., 0., 0., -0.],
[ 0., -1., 1., -1., 1., -1., -1., 0., 1., -0., -1., 0., -1., -1.,
-1., 1., 1., 1., 0., 1.],
[-1., -1., -1., 1., 0., 1., -0., 1., -1., -1., 1., -1., 1., 0.,
1., -0., -0., -1., -1., -0.],
[ 1., 1., -1., 0., 1., 1., -1., 1., 1., 1., -1., 1., 1., 1.,
-1., -1., 1., -1., -1., 1.],
[-0., -1., 0., -0., 1., -1., 1., -1., -0., 1., -1., 0., -0., -1.,
0., 0., -0., 1., 1., -1.],
[-0., 1., -0., 1., 1., -0., 1., 1., -0., 0., -1., 1., 1., -0.,
0., -1., 0., -1., -1., 1.],
[ 1., 1., -0., -0., 0., -0., -1., -1., 1., 1., -1., 1., -0., 1.,
-1., -1., 1., 1., -1., 1.],
[-1., 0., 1., 1., -1., 1., 1., 1., -1., -0., 1., 0., 1., 0.,
1., -1., -1., -1., 0., -1.],
[ 1., 1., -1., 1., 1., 0., -1., -0., 1., 1., 0., 1., -0., 1.,
-1., 1., 1., 1., -1., 1.],
[ 0., 0., 1., -1., 0., -1., 1., -1., 0., -1., -1., -1., -0., 1.,
-0., -1., 1., 1., 1., -1.],
[-1., -0., 0., -0., 0., -1., 0., 1., -1., -1., 0., -1., 0., -0.,
1., -0., -1., -0., -1., 1.]], grad_fn=<ClampBackward1>)
Disable the biases, as we don't need them for this task and it makes things easier to interpret.
for name, param in model2.named_parameters():
    if "b_" in name:
        param.requires_grad = False
Define Optimizer + Loss¶
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=lr, weight_decay=wd, betas=betas)
def loss_fn(logits, labels):
    if len(logits.shape) == 3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model2(train_data)
print(train_logits.shape)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model2(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)
torch.Size([3830, 3, 113])
tensor(4.9126, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.9016, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
print("Uniform loss:")
print(np.log(p))
Uniform loss:
4.727387818712341
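As a sanity check, ln(p) is exactly the loss of a model that assigns uniform probability 1/p to every residue, so a loss near 4.73 means the model has learned nothing about the task:

```python
import math

p = 113
# A uniform predictor assigns probability 1/p to the correct answer,
# so its cross-entropy loss is -log(1/p) = log(p).
uniform_loss = -math.log(1.0 / p)
print(round(uniform_loss, 4))  # 4.7274
```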
Actually Train¶
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if True:
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model2(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())
        optimizer2.step()
        optimizer2.zero_grad()
        with torch.inference_mode():
            test_logits = model2(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())
        if ((epoch + 1) % checkpoint_every) == 0:
            # checkpoint_epochs.append(epoch)
            # model_checkpoints.append(copy.deepcopy(model2.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")
0%| | 0/25000 [00:00<?, ?it/s]
Epoch 99 Train Loss 3.014027272855008 Test Loss 5.257051266190152
Epoch 499 Train Loss 0.9888114672109399 Test Loss 5.71135526991844
Epoch 999 Train Loss 0.41105701614316476 Test Loss 4.228014028216051
Epoch 1499 Train Loss 0.14026314735844433 Test Loss 2.4160397687297093
Epoch 1999 Train Loss 0.06827304201839962 Test Loss 0.6805699728749852
Epoch 2499 Train Loss 0.020029444704006683 Test Loss 0.06709763804440591
Epoch 2999 Train Loss 0.017477540923563777 Test Loss 0.0343058661757621
[... log truncated: train loss falls well below test loss within the first ~1000 epochs, then test loss grokks down to match it by ~epoch 2500; from there both stay low (train ~0.003-0.03, test ~0.006-0.05) with occasional transient spikes ...]
Epoch 24999 Train Loss 0.003427175596928741 Test Loss 0.007164795727361532
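Given the recorded loss curves, one crude way to quantify when grokking happens is the first epoch at which test loss falls below some threshold. A small helper (illustrative; the threshold choice is arbitrary and the function name is made up):

```python
def grokking_epoch(test_losses, threshold=0.1):
    """Return the first epoch index where test loss drops below threshold,
    or None if it never does."""
    for epoch, loss in enumerate(test_losses):
        if loss < threshold:
            return epoch
    return None

# Toy curve: a flat memorization plateau followed by a sharp drop.
toy = [5.0] * 10 + [4.0, 2.0, 0.5, 0.05, 0.01]
print(grokking_epoch(toy))  # 13
```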
model2.state_dict
<bound method Module.state_dict of BitNetTransformer(
  (emb): Embedding(114, 128)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): BitAttention(
        (to_qkv): ModuleList(
          (0-2): 3 x BitLinear(in_features=128, out_features=128, bias=False)
        )
        (to_out): BitLinear(in_features=128, out_features=128, bias=False)
      )
    )
    (ffn_layers): ModuleList(
      (0): BitFeedForward(
        (layer1): BitLinear(in_features=128, out_features=512, bias=False)
        (activation): ReLU()
        (layer2): BitLinear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (to_logits): Sequential(
    (0): RMSNorm()
    (1): Linear(in_features=128, out_features=113, bias=False)
  )
)>
Above is the architecture of the model used. We see that all the linear layers between the embed and unembed are BitLinear. Also note the RMSNorm added before the unembedding; this was found to be necessary for the model to actually train.
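For context, RMSNorm simply rescales each residual-stream vector to unit root-mean-square, optionally times a learned gain. A numpy sketch (the RMSNorm in the BitNet code may differ in details such as eps and how the gain is parameterized):

```python
import numpy as np

def rms_norm(x, gain=1.0, eps=1e-8):
    """Rescale the last axis of x to unit root-mean-square, then apply a gain."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return gain * x / rms

x = np.array([[3.0, 4.0]])
y = rms_norm(x)
print(np.mean(y * y))  # ~1.0: the normalized vector has unit RMS
```

One plausible reason this helps before the unembed is that it decouples the unembed's input scale from the magnitude of the quantized layers' outputs, though that isn't tested directly here.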
model2.emb
W_E = model2.emb.weight
W_E = W_E[:-1]
print(W_E.shape)
torch.Size([113, 128])
cache = model2.cache
print(cache.keys())
dict_keys(['attn_pattern_BitAttention', 'pre_activation_BitLinear', 'post_activation_BitLinear'])
torch.save(
    {
        "model": model2.state_dict(),
        "checkpoints": model_checkpoints,
        "checkpoint_epochs": checkpoint_epochs,
        "test_losses": test_losses,
        "train_losses": train_losses,
        "train_indices": train_indices,
        "test_indices": test_indices,
    },
    PTH_LOCATION)
if not TRAIN_MODEL:
    cached_data = torch.load(PTH_LOCATION)
    model2.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]
Show Model Training Statistics, Check that it groks!¶
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line
line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)
We see from the training curve that the model does indeed grok. The train loss does not get as low as the full-precision model's, likely due to the reduced precision afforded by binarization.
# print out all parameters of model2
for name, param in model2.named_parameters():
    print(name, param.shape)
emb.weight torch.Size([114, 128])
transformer.layers.0.to_qkv.0.weight torch.Size([128, 128])
transformer.layers.0.to_qkv.1.weight torch.Size([128, 128])
transformer.layers.0.to_qkv.2.weight torch.Size([128, 128])
transformer.layers.0.to_out.weight torch.Size([128, 128])
transformer.ffn_layers.0.layer1.weight torch.Size([512, 128])
transformer.ffn_layers.0.layer2.weight torch.Size([128, 512])
to_logits.0.gamma torch.Size([128])
to_logits.1.weight torch.Size([113, 128])
print(model2.transformer.layers[0].to_qkv[2].weight.size())
print(model2.transformer.layers[0].to_out.weight.size())
print(model2.transformer.ffn_layers[0].layer1.weight.size())
torch.Size([128, 128]) torch.Size([128, 128]) torch.Size([512, 128])
W_V = model2.transformer.layers[0].to_qkv[2].weight
W_O = model2.transformer.layers[0].to_out.weight
W_mlp_in = model2.transformer.ffn_layers[0].layer1.weight
W_mlp_in
Parameter containing:
tensor([[-8.4774e-04, 1.1625e-03, -2.1105e-02, ..., -2.7192e-02,
-2.1693e-03, -3.7147e-03],
[ 8.8908e-03, 9.3728e-03, 6.3379e-03, ..., -1.9593e-02,
-3.2319e-04, 1.5677e-03],
[-1.6904e-03, 1.6420e-02, -2.3953e-03, ..., 1.7269e-02,
7.4890e-03, 2.6184e-03],
...,
[ 2.3208e-02, 3.9736e-04, -4.8651e-02, ..., 1.4207e-02,
2.8997e-03, -4.8819e-03],
[ 1.8384e-02, -1.0097e-02, -7.1043e-03, ..., 2.0025e-03,
8.4162e-03, 2.2377e-03],
[-3.3788e-03, -2.0334e-03, 9.0140e-05, ..., -2.3906e-02,
1.8399e-02, 5.2255e-03]], device='cuda:0', requires_grad=True)
torch.sign(W_mlp_in) @ torch.sign(W_O) @ torch.sign(W_V)
tensor([[-450., -126., 394., ..., -98., 70., 354.],
[-122., 130., 202., ..., -106., -194., -102.],
[-342., 510., 26., ..., -150., -438., -914.],
...,
[-102., -2., 138., ..., 178., -182., -226.],
[ 374., -150., 122., ..., 110., -26., -38.],
[-110., -210., 686., ..., -134., -46., -658.]], device='cuda:0',
grad_fn=<MmBackward0>)
import plotly.express as px
# Convert the sign-binarized weights to numpy ints for visualization
W_mlp_in_sign_np = torch.sign(W_mlp_in).detach().cpu().numpy().astype(int)
# Create a black and white color scale
colorscale = [[0, 'white'], [1, 'black']]
# Create the image
fig = px.imshow(W_mlp_in_sign_np, color_continuous_scale=colorscale, range_color=[-1, 1])
# add a title
fig.update_layout(title="Binarized W_mlp_in visualized")
# Show the image
fig.show()
We see that no clear patterns can be discerned when we visualize one of the binarized weight matrices.
Analysing the Model¶
Standard Things to Try¶
original_logits = model2(dataset)
print(original_logits.numel())
cache = model2.cache
print(cache)
4328691
{'attn_pattern_BitAttention': tensor([[[[3.1628e-01, 3.1628e-01, 3.6744e-01],
[3.1628e-01, 3.1628e-01, 3.6744e-01],
[5.0000e-01, 5.0000e-01, 7.6192e-08]],
[[3.1232e-01, 3.1232e-01, 3.7536e-01],
[3.1232e-01, 3.1232e-01, 3.7536e-01],
[5.0000e-01, 5.0000e-01, 2.1801e-09]],
[[3.2163e-01, 3.2163e-01, 3.5674e-01],
[3.2163e-01, 3.2163e-01, 3.5674e-01],
[5.0000e-01, 5.0000e-01, 7.8238e-07]],
[[3.6129e-01, 3.6129e-01, 2.7742e-01],
[3.6129e-01, 3.6129e-01, 2.7742e-01],
[4.9985e-01, 4.9985e-01, 2.9209e-04]]],
[[[3.2695e-01, 2.9321e-01, 3.7984e-01],
[3.2993e-01, 3.0902e-01, 3.6105e-01],
[1.1135e-05, 9.9999e-01, 1.6968e-12]],
[[3.0080e-01, 3.3768e-01, 3.6152e-01],
[3.1359e-01, 3.3617e-01, 3.5024e-01],
[9.9999e-01, 5.5169e-06, 4.3602e-09]],
[[3.2203e-01, 3.2078e-01, 3.5719e-01],
[3.2125e-01, 3.2148e-01, 3.5727e-01],
[4.5999e-01, 5.4001e-01, 7.1978e-07]],
[[3.5946e-01, 3.6453e-01, 2.7601e-01],
[3.0628e-01, 3.0515e-01, 3.8857e-01],
[5.3847e-01, 4.6122e-01, 3.1466e-04]]],
[[[3.1249e-01, 3.2448e-01, 3.6303e-01],
[2.9919e-01, 3.1824e-01, 3.8257e-01],
[9.8096e-01, 1.9042e-02, 1.4948e-07]],
[[3.1571e-01, 3.0486e-01, 3.7943e-01],
[3.0338e-01, 2.8648e-01, 4.1014e-01],
[2.5107e-02, 9.7489e-01, 1.0947e-10]],
[[3.2206e-01, 3.2073e-01, 3.5721e-01],
[3.1548e-01, 3.1923e-01, 3.6529e-01],
[5.5609e-01, 4.4391e-01, 8.7014e-07]],
[[3.5604e-01, 3.7057e-01, 2.7339e-01],
[2.9086e-01, 2.9668e-01, 4.1246e-01],
[5.4626e-01, 4.5342e-01, 3.1921e-04]]],
...,
[[[3.0552e-01, 3.2013e-01, 3.7435e-01],
[3.3510e-01, 3.3411e-01, 3.3079e-01],
[9.9706e-01, 2.9432e-03, 1.0025e-11]],
[[3.2921e-01, 3.1428e-01, 3.5651e-01],
[3.3357e-01, 3.3456e-01, 3.3188e-01],
[3.0679e-03, 9.9693e-01, 1.4957e-07]],
[[3.3452e-01, 3.3420e-01, 3.3128e-01],
[3.2990e-01, 3.3133e-01, 3.3877e-01],
[6.1166e-01, 3.8834e-01, 6.4634e-07]],
[[2.2813e-01, 2.3876e-01, 5.3311e-01],
[3.0627e-01, 3.1411e-01, 3.7962e-01],
[4.2445e-01, 5.7519e-01, 3.5980e-04]]],
[[[2.9248e-01, 3.4914e-01, 3.5838e-01],
[1.9432e-01, 3.8272e-01, 4.2296e-01],
[1.0000e+00, 2.6060e-10, 1.0055e-11]],
[[3.4256e-01, 2.8647e-01, 3.7097e-01],
[3.4954e-01, 1.7629e-01, 4.7417e-01],
[2.0963e-10, 1.0000e+00, 1.0220e-14]],
[[3.3123e-01, 3.4076e-01, 3.2801e-01],
[2.8515e-01, 3.0490e-01, 4.0996e-01],
[7.1983e-01, 2.8017e-01, 7.6064e-07]],
[[2.2877e-01, 2.3663e-01, 5.3460e-01],
[2.9951e-01, 3.0904e-01, 3.9145e-01],
[4.0025e-01, 5.9941e-01, 3.3929e-04]]],
[[[3.1005e-01, 3.1005e-01, 3.7990e-01],
[3.1005e-01, 3.1005e-01, 3.7990e-01],
[5.0000e-01, 5.0000e-01, 5.0273e-12]],
[[3.2437e-01, 3.2437e-01, 3.5127e-01],
[3.2437e-01, 3.2437e-01, 3.5127e-01],
[4.9999e-01, 4.9999e-01, 2.4376e-05]],
[[3.3441e-01, 3.3441e-01, 3.3117e-01],
[3.3441e-01, 3.3441e-01, 3.3117e-01],
[5.0000e-01, 5.0000e-01, 5.2835e-07]],
[[2.3058e-01, 2.3058e-01, 5.3884e-01],
[2.3058e-01, 2.3058e-01, 5.3884e-01],
[4.9979e-01, 4.9979e-01, 4.2367e-04]]]], device='cuda:0'), 'pre_activation_BitLinear': tensor([[[ 106.4083, -51.5264, -66.1423, ..., 118.3598, -169.8047,
-13.1684],
[ 106.4083, -51.5264, -66.1423, ..., 118.3598, -169.8047,
-13.1684],
[ 99.2299, -27.9482, -88.6253, ..., 35.5822, -172.9445,
-163.3766]],
[[ 107.8512, -57.9051, -91.8784, ..., 78.5897, -142.8357,
49.7473],
[ 85.4390, -70.4141, -106.5268, ..., 59.7271, -146.2981,
68.0476],
[ 110.0683, 18.9105, -169.0832, ..., -39.0251, -149.7223,
-117.6599]],
[[ 94.1846, -56.8763, -13.4679, ..., 111.3792, -159.8546,
59.2746],
[ 70.3230, -69.5398, -27.9997, ..., 96.1940, -158.7236,
86.1435],
[ 36.1063, -132.2721, 22.5842, ..., -122.2332, -161.5067,
-90.5593]],
...,
[[ 77.0760, -76.4591, -54.7987, ..., 44.9784, -135.3834,
90.1697],
[ 114.0575, -60.9497, -39.9740, ..., 67.2400, -136.3208,
64.6253],
[ 182.2187, -63.4982, -27.7316, ..., 157.1467, -161.4402,
-15.8297]],
[[ 66.8944, -79.9402, -40.0839, ..., 15.5541, -102.0210,
162.0396],
[ 106.8590, -69.7901, -25.2819, ..., 55.1248, -96.3322,
142.7734],
[ 142.4884, -44.6055, 32.8400, ..., 20.4753, -117.8544,
97.6316]],
[[ 68.6475, -64.9029, 37.6709, ..., 46.7013, -147.7025,
112.0250],
[ 68.6475, -64.9029, 37.6709, ..., 46.7013, -147.7025,
112.0250],
[ 134.6232, -9.4157, 118.7221, ..., -10.6936, -135.5985,
-25.6514]]], device='cuda:0'), 'post_activation_BitLinear': tensor([[[106.4083, 0.0000, 0.0000, ..., 118.3598, 0.0000, 0.0000],
[106.4083, 0.0000, 0.0000, ..., 118.3598, 0.0000, 0.0000],
[ 99.2299, 0.0000, 0.0000, ..., 35.5822, 0.0000, 0.0000]],
[[107.8512, 0.0000, 0.0000, ..., 78.5897, 0.0000, 49.7473],
[ 85.4390, 0.0000, 0.0000, ..., 59.7271, 0.0000, 68.0476],
[110.0683, 18.9105, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 94.1846, 0.0000, 0.0000, ..., 111.3792, 0.0000, 59.2746],
[ 70.3230, 0.0000, 0.0000, ..., 96.1940, 0.0000, 86.1435],
[ 36.1063, 0.0000, 22.5842, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 77.0760, 0.0000, 0.0000, ..., 44.9784, 0.0000, 90.1697],
[114.0575, 0.0000, 0.0000, ..., 67.2400, 0.0000, 64.6253],
[182.2187, 0.0000, 0.0000, ..., 157.1467, 0.0000, 0.0000]],
[[ 66.8944, 0.0000, 0.0000, ..., 15.5541, 0.0000, 162.0396],
[106.8590, 0.0000, 0.0000, ..., 55.1248, 0.0000, 142.7734],
[142.4884, 0.0000, 32.8400, ..., 20.4753, 0.0000, 97.6316]],
[[ 68.6475, 0.0000, 37.6709, ..., 46.7013, 0.0000, 112.0250],
[ 68.6475, 0.0000, 37.6709, ..., 46.7013, 0.0000, 112.0250],
[134.6232, 0.0000, 118.7221, ..., 0.0000, 0.0000, 0.0000]]],
device='cuda:0')}
Get key weight matrices:
W_E = model2.emb.weight
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)
Original Loss: 0.006048389894853787
Looking at Activations¶
Helper variable:
neuron_acts = cache["post_activation_BitLinear"][:, -1, :]
neuron_pre_acts = cache["pre_activation_BitLinear"][:, -1, :]
print(neuron_acts.size())
torch.Size([12769, 512])
Get all shapes:
for param_name, param in cache.items():
    print(param_name, param.shape)
attn_pattern_BitAttention torch.Size([12769, 4, 3, 3])
pre_activation_BitLinear torch.Size([12769, 3, 512])
post_activation_BitLinear torch.Size([12769, 3, 512])
imshow(cache["attn_pattern_BitAttention"].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])
dataset[:4]
tensor([[ 0, 0, 113],
[ 0, 1, 113],
[ 0, 2, 113],
[ 0, 3, 113]], device='cuda:0')
cache["attn_pattern_BitAttention"].shape
torch.Size([12769, 4, 3, 3])
imshow(cache["attn_pattern_BitAttention"][:, 0, -1, 0].reshape(p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")
imshow(
    einops.rearrange(cache["attn_pattern_BitAttention"][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention per Head from a -> =", xaxis="b", yaxis="a", facet_col=0)
Plotting neuron activations
imshow(
einops.rearrange(neuron_pre_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)
Singular Value Decomposition¶
W_E.shape
# take off the last row
W_E = W_E[:-1]
W_E.shape
torch.Size([113, 128])
U, S, Vh = torch.svd(W_E)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")
One difference from the full-precision grokked model is that there seem to be more significant components.
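To make "more components" concrete, one option (a hypothetical metric I am introducing here, not part of the original analysis) is to count singular values above some fraction of the largest one and compare against the random control below:

```python
import torch

def effective_rank(M, threshold=0.05):
    # Count singular values above `threshold` times the largest one --
    # a crude proxy for the number of meaningful principal components.
    # The 5% cutoff is an arbitrary illustrative choice.
    S = torch.linalg.svdvals(M)  # returned in descending order
    return int((S > threshold * S[0]).sum())

torch.manual_seed(0)
# A rank-8 matrix should report 8; a Gaussian matrix is close to full rank
low_rank = torch.randn(113, 8) @ torch.randn(8, 128)
print(effective_rank(low_rank))              # 8
print(effective_rank(torch.randn(113, 128)))
```

Applied to `W_E` and to the random Gaussian control, this would give a single number for each spectrum rather than eyeballing the singular-value plot.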
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")
Explaining Algorithm¶
Analyse the Embedding - It's a Lookup Table!¶
U, S, Vh = torch.svd(W_E)
line(U[:, :15].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+2):
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")
Analyse the Embedding¶
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")
Once again this is a lot less "clean" than in the full-precision model, but it still seems to be fundamentally the same structure.
import plotly.graph_objects as go
# Compute the norm
norm_values = (fourier_basis @ W_E).norm(dim=-1)
# Convert the tensor to a numpy array for plotting
norm_values_np = norm_values.detach().cpu().numpy()
# Create the plot
fig = go.Figure(data=go.Scatter(y=norm_values_np))
# Set the title and labels
fig.update_layout(title='Norm of Tensor', xaxis_title='Index', yaxis_title='Norm')
# Show the plot
fig.show()
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")
# key_freqs = [17, 25, 32, 47]
key_freq_indices = [9,10,23,24,59,60,65,66,67,68,79,80,97,98]
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")
key_fourier_embed torch.Size([14, 128])
One difference this graph shows compared to the full-precision model is that the terms are not as orthogonal. I hypothesize that this is due to the reduced precision from binarization.
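To quantify "not as orthogonal" (again a hypothetical metric, not from the original analysis), one can compare off-diagonal to diagonal mass in the Gram matrix; for a perfectly orthogonal set of rows this ratio is zero:

```python
import torch

def off_diagonal_ratio(M):
    # Gram-matrix orthogonality score: ||off-diagonal|| / ||diagonal||.
    # 0 means the rows of M are exactly orthogonal; larger values mean
    # more cross-talk between components.
    gram = M @ M.T
    diag = torch.diag(torch.diagonal(gram))
    off = gram - diag
    return (off.norm() / diag.norm()).item()

# Orthogonal rows -> ratio 0; random rows -> noticeably above 0
print(off_diagonal_ratio(torch.eye(4, 8)))  # 0.0
print(off_diagonal_ratio(torch.randn(4, 8)))
```

Applying this to `key_fourier_embed` would put a single number on the cross-talk visible in the dot-product plot above, and allow a direct comparison with the full-precision model.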
Key Frequencies¶
import neel_plotly as npx
key_cos = [num for num in key_freq_indices if num % 2 == 0]
npx.line(fourier_basis[key_cos], title="Cos of key freqs")
npx.line(fourier_basis[key_cos].mean(0), title="Constructive Interference")
Analyse Neurons¶
imshow(
einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)
imshow(
einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p),
title="First neuron act", xaxis="b", yaxis="a",)
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)
imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 5", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)
Neuron Clusters¶
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)
fourier_neuron_acts torch.Size([512, 115, 115])
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()
for freq in range(0, p//2):
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")
Summary of Results¶
Overall, all of the graphs I've generated align with the corresponding graphs of the full-precision model from the original reverse-engineering modular addition walkthrough: https://youtu.be/o0FppeD_xXQ?si=ObA2aISAUQI_H2GC
While this investigation is not exhaustive, it is at least clear that there is no evidence to support my initial hypothesis that binarized transformers learn a more discretized and more interpretable representation. Instead, all of the evidence suggests that the binarized setup learns an algorithm that is essentially the same as the one learned by the full-precision model.
From this preliminary investigation, I further hypothesize that this is because the embed and unembed layers at the start and end are not binarized, so they remain free to learn the Fourier transform, which is the most important ingredient of the algorithm. Furthermore, I suspect that binarized networks will in general end up learning approximations of full-precision networks: optimization techniques such as the straight-through estimator of the gradient used by BitNet deliberately treat the binarization mechanism as a continuous function to be optimized over.
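The straight-through estimator can be demonstrated in a few lines (a generic sketch of the trick, not BitNet's exact code): the forward pass sees the binarized weights, while the backward pass treats the binarization as the identity, so gradients keep updating the latent full-precision weights exactly as they would in an unquantized network.

```python
import torch

# Latent full-precision weight being trained
w = torch.tensor([0.3, -0.7, 0.1], requires_grad=True)

# Forward: binarized weight; backward: gradient of the identity,
# because the (sign(w) - w) correction is detached from the graph
w_bin = w + (torch.sign(w) - w).detach()

loss = (w_bin * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()

print(w_bin.detach())  # tensor([ 1., -1.,  1.])
print(w.grad)          # tensor([1., 2., 3.]) -- as if sign() were identity
```

This is why the binarized model has every incentive to track the loss landscape of its full-precision counterpart: the optimizer never "sees" the discreteness at all.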